import torch
from data.load_synthetic import load_blender_data, poses2pytorch3dcam

class BlenderDataset(torch.utils.data.Dataset):
    """Per-view dataset over a scene loaded with ``load_blender_data``.

    Each item is a dict holding one view of the selected split: its camera
    (from ``poses2pytorch3dcam``), its image (first three channels of the
    loaded image, as float) and its silhouette (fourth channel, as float).

    Args:
        split: ``'train'``, ``'val'`` or ``'test'``. Any other value falls back
            to the test views (preserved for backward compatibility).
        epoch_len: Number of items reported by ``len()`` for the train split;
            indices beyond the number of views wrap around. Ignored for the
            other splits, whose length is the number of views.
        **synthetic_args: Forwarded unchanged to ``load_blender_data``.
    """

    def __init__(self, split='train', epoch_len=1000, **synthetic_args):
        images, poses, _render_poses, [H, W, focal], i_split = load_blender_data(**synthetic_args)
        train_ids, val_ids, test_ids = i_split
        # Unknown split names deliberately map to the test views, matching the
        # original if/elif/else behaviour.
        split_ids = {'train': train_ids, 'val': val_ids}.get(split, test_ids)

        # Index the image array once: fancy indexing copies, so indexing it
        # separately for RGB and silhouette would duplicate the whole split.
        # NOTE(review): assumes `images` is a numpy array whose last axis is
        # RGBA (RGB in [..., :3], alpha used as the silhouette) — confirm.
        split_images = images[split_ids]
        self.target_images = torch.from_numpy(split_images[..., :3]).float()
        self.target_silhouettes = torch.from_numpy(split_images[..., 3]).float()
        self.target_cameras = poses2pytorch3dcam(poses[split_ids], H, W, focal)

        # Train epochs have a configurable length (views are revisited via the
        # modulo in __getitem__); val/test visit each view exactly once.
        self.epoch_len = epoch_len if split == 'train' else len(self.target_cameras)

    def __getitem__(self, i):
        """Return the camera, image and silhouette of view ``i`` (wrapped)."""
        # One wrapped index shared by all fields so they always describe the
        # same view, even when the epoch is longer than the number of views.
        idx = i % len(self.target_images)
        return {
            'target_camera': self.target_cameras[idx],
            'target_image': self.target_images[idx],
            'target_silhouette': self.target_silhouettes[idx],
        }

    def __len__(self):
        """Return the epoch length (see ``epoch_len`` in the class docstring)."""
        return self.epoch_len